{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GAN Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from models.GAN import GAN\n",
    "from utils.loaders import load_safari\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run params\n",
    "SECTION = 'gan'\n",
    "RUN_ID = '0001'\n",
    "DATA_NAME = 'camel'\n",
    "RUN_FOLDER = 'run/{}/'.format(SECTION)\n",
    "RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])\n",
    "\n",
    "if not os.path.exists(RUN_FOLDER):\n",
    "    os.mkdir(RUN_FOLDER)\n",
    "    os.mkdir(os.path.join(RUN_FOLDER, 'viz'))\n",
    "    os.mkdir(os.path.join(RUN_FOLDER, 'images'))\n",
    "    os.mkdir(os.path.join(RUN_FOLDER, 'weights'))\n",
    "\n",
    "mode =  'build' #'load' #"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset selected by DATA_NAME; only x_train is used by the cells below.\n",
    "x_train, y_train = load_safari(DATA_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: display the dimensions of the loaded array.\n",
    "x_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview one training example (index 200, channel 0) with a grey colour map.\n",
    "plt.imshow(x_train[200, :, :, 0], cmap='gray')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters for both networks, passed as keyword arguments to the\n",
    "# project's GAN class (imported from models.GAN, which defines their meaning).\n",
    "gan = GAN(\n",
    "    input_dim=(28, 28, 1),\n",
    "    # --- discriminator ---\n",
    "    discriminator_conv_filters=[64, 64, 128, 128],\n",
    "    discriminator_conv_kernel_size=[5, 5, 5, 5],\n",
    "    discriminator_conv_strides=[2, 2, 2, 1],\n",
    "    discriminator_batch_norm_momentum=None,\n",
    "    discriminator_activation='relu',\n",
    "    discriminator_dropout_rate=0.4,\n",
    "    discriminator_learning_rate=0.0008,\n",
    "    # --- generator ---\n",
    "    generator_initial_dense_layer_size=(7, 7, 64),\n",
    "    generator_upsample=[2, 2, 1, 1],\n",
    "    generator_conv_filters=[128, 64, 64, 1],\n",
    "    generator_conv_kernel_size=[5, 5, 5, 5],\n",
    "    generator_conv_strides=[1, 1, 1, 1],\n",
    "    generator_batch_norm_momentum=0.9,\n",
    "    generator_activation='relu',\n",
    "    generator_dropout_rate=None,\n",
    "    generator_learning_rate=0.0004,\n",
    "    # --- shared ---\n",
    "    optimiser='rmsprop',\n",
    "    z_dim=100,\n",
    ")\n",
    "\n",
    "# A fresh build calls gan.save(RUN_FOLDER); any other mode loads the\n",
    "# weights file stored under the run folder.\n",
    "if mode != 'build':\n",
    "    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))\n",
    "else:\n",
    "    gan.save(RUN_FOLDER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the summary of the discriminator model built above.\n",
    "gan.discriminator.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the summary of the generator model built above.\n",
    "gan.generator.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training configuration, passed to gan.train in the next cell.\n",
    "# NOTE(review): the plots further down label their x-axis 'batch'; confirm\n",
    "# whether gan.train treats one epoch as a single batch.\n",
    "BATCH_SIZE = 64\n",
    "EPOCHS = 6000\n",
    "PRINT_EVERY_N_BATCHES = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Run training with the configuration defined above; outputs go to RUN_FOLDER.\n",
    "gan.train(\n",
    "    x_train,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    epochs=EPOCHS,\n",
    "    run_folder=RUN_FOLDER,\n",
    "    print_every_n_batches=PRINT_EVERY_N_BATCHES,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss curves: entries 0-2 of every gan.d_losses record (black, green, red)\n",
    "# and entry 0 of every gan.g_losses record (orange).\n",
    "fig = plt.figure()\n",
    "\n",
    "for column, colour in enumerate(('black', 'green', 'red')):\n",
    "    plt.plot([row[column] for row in gan.d_losses], color=colour, linewidth=0.25)\n",
    "plt.plot([row[0] for row in gan.g_losses], color='orange', linewidth=0.25)\n",
    "\n",
    "plt.xlabel('batch', fontsize=18)\n",
    "plt.ylabel('loss', fontsize=16)\n",
    "\n",
    "# The view is restricted to x in [0, 2000] and y in [0, 2].\n",
    "plt.xlim(0, 2000)\n",
    "plt.ylim(0, 2)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy curves: entries 3-5 of every gan.d_losses record (black, green, red)\n",
    "# and entry 1 of every gan.g_losses record (orange).\n",
    "fig = plt.figure()\n",
    "\n",
    "for column, colour in zip((3, 4, 5), ('black', 'green', 'red')):\n",
    "    plt.plot([row[column] for row in gan.d_losses], color=colour, linewidth=0.25)\n",
    "plt.plot([row[1] for row in gan.g_losses], color='orange', linewidth=0.25)\n",
    "\n",
    "plt.xlabel('batch', fontsize=18)\n",
    "plt.ylabel('accuracy', fontsize=16)\n",
    "\n",
    "# The view is restricted to x in [0, 2000].\n",
    "plt.xlim(0, 2000)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gdl_code",
   "language": "python",
   "name": "gdl_code"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
